{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xf3lVTZYhbzA"
   },
   "source": [
    "# Initial Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2ORFXeezn5Og"
   },
   "source": [
    "## (Google Colab use only)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 3570,
     "status": "ok",
     "timestamp": 1620418927808,
     "user": {
      "displayName": "Ronald Seoh",
      "photoUrl": "",
      "userId": "10284188050297676522"
     },
     "user_tz": 240
    },
    "id": "YFAQ6IgXn8FK",
    "outputId": "25f6ccd2-93f3-4714-9551-e47ee5916705"
   },
   "outputs": [],
   "source": [
    "# Set to True to mount Google Drive and install the requirements when running on Colab\n",
    "use_colab = False\n",
    "\n",
    "# Is this notebook running on Colab?\n",
    "# If so, then google.colab package (github.com/googlecolab/colabtools)\n",
    "# should be available in this environment, so we simply attempt to import it.\n",
    "try:\n",
    "    from google.colab import drive\n",
    "    colab_available = True\n",
    "except ImportError:\n",
    "    # Only the missing-package case is caught, so that unrelated errors\n",
    "    # (including KeyboardInterrupt) are not silently swallowed.\n",
    "    colab_available = False\n",
    "\n",
    "if use_colab and colab_available:\n",
    "    drive.mount('/content/drive')\n",
    "\n",
    "    # cd to the appropriate working directory under my Google Drive\n",
    "    %cd '/content/drive/My Drive/cs696ds_lexalytics/Ronald Gypsum Prompts'\n",
    "    \n",
    "    # Install packages specified in requirements\n",
    "    !pip install -r requirements.txt\n",
    "\n",
    "    %cd 'prompts_subtask4'\n",
    "    \n",
    "    # List the directory contents\n",
    "    !ls"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tgzsHF7Zhbzo"
   },
   "source": [
    "## Experiment parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 3559,
     "status": "ok",
     "timestamp": 1620418927816,
     "user": {
      "displayName": "Ronald Seoh",
      "photoUrl": "",
      "userId": "10284188050297676522"
     },
     "user_tz": 240
    },
    "id": "DUpGBmOJhbzs",
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# String ID identifying this particular (training) experiment;\n",
    "# it is used in directory paths and other settings.\n",
    "experiment_id = 'bert_16_shot_prompt_logit_softmax_atsc_restaurants_bert_yelp_restaurants_multiple_prompts_589'\n",
    "\n",
    "# Seed for the random number generators\n",
    "random_seed = 589\n",
    "\n",
    "# Either a path to a pretrained MLM model folder or the string \"bert-base-uncased\"\n",
    "lm_model_path = 'bert-base-uncased'\n",
    "\n",
    "# Prompts appended to the end of each review text.\n",
    "# Note: the pseudo-labels of every prompt must be ordered as (positive), (negative), (neutral).\n",
    "sentiment_prompts = [\n",
    "    {\n",
    "        \"prompt\": \"I felt the {aspect} was [MASK].\",\n",
    "        \"labels\": [\"good\", \"bad\", \"ok\"],\n",
    "    },\n",
    "    {\n",
    "        \"prompt\": \"I [MASK] the {aspect}.\",\n",
    "        \"labels\": [\"love\", \"hate\", \"dislike\"],\n",
    "    },\n",
    "    {\n",
    "        \"prompt\": \"The {aspect} made me feel [MASK].\",\n",
    "        \"labels\": [\"good\", \"bad\", \"indifferent\"],\n",
    "    },\n",
    "    {\n",
    "        \"prompt\": \"The {aspect} is [MASK].\",\n",
    "        \"labels\": [\"good\", \"bad\", \"ok\"],\n",
    "    },\n",
    "]\n",
    "\n",
    "# How the outputs of the multiple prompts are merged\n",
    "prompts_merge_behavior = 'sum_logits'\n",
    "\n",
    "# Whether to perturb the input embeddings of tokens within the prompts\n",
    "prompts_perturb = False\n",
    "\n",
    "# If True, no trained ATSC weights are loaded and the pretrained LM is evaluated as-is\n",
    "no_atsc_training = False\n",
    "\n",
    "# Test settings\n",
    "testing_batch_size = 32\n",
    "testing_domain = 'restaurants' # 'laptops', 'restaurants', 'joint'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 3552,
     "status": "ok",
     "timestamp": 1620418927819,
     "user": {
      "displayName": "Ronald Seoh",
      "photoUrl": "",
      "userId": "10284188050297676522"
     },
     "user_tz": 240
    },
    "id": "AtJhBPXMY36f"
   },
   "outputs": [],
   "source": [
    "# Batch size adjustment for multiple prompts: every review is repeated once per\n",
    "# prompt at test time, so the number of reviews per batch is divided accordingly.\n",
    "# Keep at least one review per batch; integer division alone would yield 0\n",
    "# whenever there are more prompts than testing_batch_size.\n",
    "testing_batch_size = max(1, testing_batch_size // len(sentiment_prompts))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 3546,
     "status": "ok",
     "timestamp": 1620418927824,
     "user": {
      "displayName": "Ronald Seoh",
      "photoUrl": "",
      "userId": "10284188050297676522"
     },
     "user_tz": 240
    },
    "id": "keCSh__SY36i",
    "outputId": "7d83760a-ac16-481c-c9e4-6633f493b37d"
   },
   "outputs": [],
   "source": [
    "print(f\"Experiment ID: {experiment_id}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GYZesqTioMvF"
   },
   "source": [
    "## Package imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 5595,
     "status": "ok",
     "timestamp": 1620418929888,
     "user": {
      "displayName": "Ronald Seoh",
      "photoUrl": "",
      "userId": "10284188050297676522"
     },
     "user_tz": 240
    },
    "id": "MlK_-DrWhbzb",
    "outputId": "5854fa11-ce1f-49a2-a493-6d6b1fb92423"
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import random\n",
    "import shutil\n",
    "import copy\n",
    "import inspect\n",
    "import json\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import transformers\n",
    "import datasets\n",
    "import sklearn.metrics\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sn\n",
    "import tqdm\n",
    "import tqdm.notebook\n",
    "\n",
    "# Make the parent directory importable so that the shared `utils` module can be found.\n",
    "# The working directory is used directly: the file name that inspect reports for a\n",
    "# notebook cell is kernel-dependent and does not reliably point at the notebook folder.\n",
    "current_dir = os.getcwd()\n",
    "parent_dir = os.path.dirname(current_dir)\n",
    "\n",
    "# Avoid stacking duplicate entries when this cell is re-run\n",
    "if parent_dir not in sys.path:\n",
    "    sys.path.append(parent_dir)\n",
    "\n",
    "import utils\n",
    "\n",
    "# Random seed settings\n",
    "random.seed(random_seed)\n",
    "np.random.seed(random_seed)\n",
    "\n",
    "# cuBLAS reproducibility\n",
    "# https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility\n",
    "os.environ['CUBLAS_WORKSPACE_CONFIG'] = \":4096:8\"\n",
    "\n",
    "# Request deterministic algorithms; prefer the current API name and fall back to\n",
    "# the older, since-deprecated one when running on an older PyTorch release.\n",
    "if hasattr(torch, 'use_deterministic_algorithms'):\n",
    "    torch.use_deterministic_algorithms(True)\n",
    "else:\n",
    "    torch.set_deterministic(True)\n",
    "\n",
    "torch.manual_seed(random_seed)\n",
    "\n",
    "# Print version information\n",
    "print(\"Python version: \" + sys.version)\n",
    "print(\"NumPy version: \" + np.__version__)\n",
    "print(\"PyTorch version: \" + torch.__version__)\n",
    "print(\"Transformers version: \" + transformers.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UWuR30eUoTWP"
   },
   "source": [
    "## PyTorch GPU settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 5585,
     "status": "ok",
     "timestamp": 1620418929892,
     "user": {
      "displayName": "Ronald Seoh",
      "photoUrl": "",
      "userId": "10284188050297676522"
     },
     "user_tz": 240
    },
    "id": "PfNlm-ykoSlM",
    "outputId": "9b7cc30c-e6ae-404d-ce4f-b2afbdbec29d"
   },
   "outputs": [],
   "source": [
    "if not torch.cuda.is_available():\n",
    "    # CPU fallback: no page-locked memory, a single compute device\n",
    "    torch_device = torch.device('cpu')\n",
    "    use_pin_memory = False\n",
    "    training_device_count = 1\n",
    "\n",
    "else:\n",
    "    torch_device = torch.device('cuda')\n",
    "\n",
    "    # Deterministic cuDNN so that the output is immediately reproducible\n",
    "    # Note: https://pytorch.org/docs/stable/notes/randomness.html\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "\n",
    "    # 'benchmark' mode is disabled, which also makes running times more comparable\n",
    "    # Note: https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "    # Page-locked memory for faster host to GPU copies\n",
    "    use_pin_memory = True\n",
    "\n",
    "    # Number of compute devices to be used for training\n",
    "    training_device_count = torch.cuda.device_count()\n",
    "\n",
    "    # CUDA libraries version information\n",
    "    print(f\"CUDA Version: {torch.version.cuda}\")\n",
    "    print(f\"cuDNN Version: {torch.backends.cudnn.version()}\")\n",
    "    print(f\"CUDA Device Name: {torch.cuda.get_device_name()}\")\n",
    "    print(f\"CUDA Capabilities: {torch.cuda.get_device_capability()}\")\n",
    "    print(f\"Number of CUDA devices: {training_device_count}\")\n",
    "\n",
    "print()\n",
    "print(\"PyTorch device selected:\", torch_device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ayX5VRLfocFk"
   },
   "source": [
    "# Prepare Datasets for Prompt-based Classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "U9LAAJP-hbz7"
   },
   "source": [
    "## Load the SemEval dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 5575,
     "status": "ok",
     "timestamp": 1620418929895,
     "user": {
      "displayName": "Ronald Seoh",
      "photoUrl": "",
      "userId": "10284188050297676522"
     },
     "user_tz": 240
    },
    "id": "gpL2uHPUhbz9",
    "outputId": "41504d48-f3c5-4361-a055-e24d3046f9c8"
   },
   "outputs": [],
   "source": [
    "# Load the SemEval 2014 Task 4 restaurants data (train and test files)\n",
    "# through the local dataset loading script.\n",
    "restaurants_dataset = datasets.load_dataset(\n",
    "    path=os.path.abspath('../dataset_scripts/semeval2014_task4/semeval2014_task4.py'),\n",
    "    name=\"SemEval2014Task4Dataset - Subtask 4\",\n",
    "    cache_dir='../dataset_cache',\n",
    "    data_files={\n",
    "        'test': '../dataset_files/semeval_2014/Restaurants_Test_Gold.xml',\n",
    "        'train': '../dataset_files/semeval_2014/Restaurants_Train_v2.xml',\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 5565,
     "status": "ok",
     "timestamp": 1620418929898,
     "user": {
      "displayName": "Ronald Seoh",
      "photoUrl": "",
      "userId": "10284188050297676522"
     },
     "user_tz": 240
    },
    "id": "Gi5m8AbPj1iJ"
   },
   "outputs": [],
   "source": [
    "# The dataset chosen for testing.\n",
    "# Only the restaurants data is loaded in this notebook, so any other domain is\n",
    "# rejected here instead of failing later with an undefined `test_set`.\n",
    "if testing_domain == 'restaurants':\n",
    "    test_set = restaurants_dataset['test']\n",
    "else:\n",
    "    raise ValueError(\n",
    "        f\"Unsupported testing_domain {testing_domain!r}: only 'restaurants' is loaded in this notebook.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 5557,
     "status": "ok",
     "timestamp": 1620418929901,
     "user": {
      "displayName": "Ronald Seoh",
      "photoUrl": "",
      "userId": "10284188050297676522"
     },
     "user_tz": 240
    },
    "id": "Est9ao9rcH4l",
    "outputId": "8487f46b-593a-43f6-b98c-eb559af8d169"
   },
   "outputs": [],
   "source": [
    "# Sanity check: number of examples in the selected test set\n",
    "print(len(test_set))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 5547,
     "status": "ok",
     "timestamp": 1620418929903,
     "user": {
      "displayName": "Ronald Seoh",
      "photoUrl": "",
      "userId": "10284188050297676522"
     },
     "user_tz": 240
    },
    "id": "_npZeCIqcKjT",
    "outputId": "8951acb2-32ca-4e07-e1cf-011aa6831a69"
   },
   "outputs": [],
   "source": [
    "# Show one example (index 4) to inspect the fields of a test record\n",
    "print(test_set[4])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6TOMmAtIvoZ_"
   },
   "source": [
    "# Zero-shot / Few-shot ATSC with Prompts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3jNAtuv-hbzv"
   },
   "source": [
    "## Initialize BERT MLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 9973,
     "status": "ok",
     "timestamp": 1620418934341,
     "user": {
      "displayName": "Ronald Seoh",
      "photoUrl": "",
      "userId": "10284188050297676522"
     },
     "user_tz": 240
    },
    "id": "En2BmfjVhbzy"
   },
   "outputs": [],
   "source": [
    "if not no_atsc_training:\n",
    "    # The trained weights are loaded in a later cell, so only the configuration\n",
    "    # of the pretrained model is used here, not its actual weights.\n",
    "    lm_config = transformers.AutoConfig.from_pretrained(\n",
    "        'bert-base-uncased', cache_dir='../bert_base_cache')\n",
    "    lm = transformers.AutoModelForMaskedLM.from_config(lm_config)\n",
    "else:\n",
    "    # No ATSC training: use the pretrained language model weights as they are\n",
    "    lm = transformers.AutoModelForMaskedLM.from_pretrained(lm_model_path)\n",
    "\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(\n",
    "    'bert-base-uncased', cache_dir='../bert_base_cache')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TEIbN5Xthb0o"
   },
   "source": [
    "## Define a new model with non-trainable softmax head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 12519,
     "status": "ok",
     "timestamp": 1620418936896,
     "user": {
      "displayName": "Ronald Seoh",
      "photoUrl": "",
      "userId": "10284188050297676522"
     },
     "user_tz": 240
    },
    "id": "wN3q4Rsopxby",
    "outputId": "f9336495-e129-4ca0-d311-6bc3c2f38c20"
   },
   "outputs": [],
   "source": [
    "# Look up the vocabulary id of every pseudo-label word, one list per prompt\n",
    "sentiment_word_ids = [\n",
    "    [tokenizer.convert_tokens_to_ids(word) for word in prompt_spec['labels']]\n",
    "    for prompt_spec in sentiment_prompts\n",
    "]\n",
    "\n",
    "print(sentiment_word_ids)\n",
    "\n",
    "# Wrap the language model with the multi-prompt classification head\n",
    "# and place the result on the selected device.\n",
    "classifier_model = utils.MultiPromptLogitSentimentClassificationHead(\n",
    "    lm=lm,\n",
    "    num_class=3,\n",
    "    num_prompts=len(sentiment_prompts),\n",
    "    pseudo_label_words=sentiment_word_ids,\n",
    "    target_token_id=tokenizer.mask_token_id,\n",
    "    merge_behavior=prompts_merge_behavior,\n",
    "    perturb_prompts=prompts_perturb,\n",
    ").to(device=torch_device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1U6B5GNSYBYk"
   },
   "source": [
    "## Load our saved weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 14206,
     "status": "ok",
     "timestamp": 1620418938595,
     "user": {
      "displayName": "Ronald Seoh",
      "photoUrl": "",
      "userId": "10284188050297676522"
     },
     "user_tz": 240
    },
    "id": "gLo25nUcYBGx",
    "outputId": "f5fef2db-f954-4d65-b4c9-c8e93a269b5d"
   },
   "outputs": [],
   "source": [
    "# Locate the weight file.\n",
    "trained_model_directory = os.path.join('..', 'trained_models_prompts', experiment_id)\n",
    "\n",
    "if no_atsc_training:\n",
    "    # Nothing to load; just make sure the directory exists, since the\n",
    "    # test metrics are written into it later.\n",
    "    os.makedirs(trained_model_directory, exist_ok=True)\n",
    "else:\n",
    "    # os.listdir() returns entries in arbitrary order, so the candidates are sorted\n",
    "    # to make the choice deterministic when several 'epoch*' files exist.\n",
    "    epoch_files = sorted(\n",
    "        fname for fname in os.listdir(trained_model_directory)\n",
    "        if fname.startswith('epoch'))\n",
    "\n",
    "    # Fail with a clear message rather than passing the directory itself to torch.load\n",
    "    if not epoch_files:\n",
    "        raise FileNotFoundError(\n",
    "            f\"No saved weights (files starting with 'epoch') found in {trained_model_directory}\")\n",
    "\n",
    "    saved_weights_name = epoch_files[0]\n",
    "\n",
    "    print(\"Loading\", saved_weights_name)\n",
    "\n",
    "    classifier_model.load_state_dict(torch.load(\n",
    "        os.path.join(trained_model_directory, saved_weights_name),\n",
    "        map_location=torch_device))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1l1H_XIPhb0y"
   },
   "source": [
    "## Evaluation with in-domain test set\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 14198,
     "status": "ok",
     "timestamp": 1620418938597,
     "user": {
      "displayName": "Ronald Seoh",
      "photoUrl": "",
      "userId": "10284188050297676522"
     },
     "user_tz": 240
    },
    "id": "0S80DoYrqApi"
   },
   "outputs": [],
   "source": [
    "def compute_metrics(predictions, labels):\n",
    "    \"\"\"Compute accuracy and macro-averaged precision, recall and F1.\n",
    "\n",
    "    `predictions` holds per-class scores; the predicted class of each example\n",
    "    is the argmax over the last axis. Classes 0, 1 and 2 are always included\n",
    "    in the macro average.\n",
    "\n",
    "    Returns a dict with the keys 'accuracy', 'f1', 'precision' and 'recall'.\n",
    "    \"\"\"\n",
    "    predicted_classes = predictions.argmax(-1)\n",
    "\n",
    "    macro_precision, macro_recall, macro_f1, _ = (\n",
    "        sklearn.metrics.precision_recall_fscore_support(\n",
    "            y_true=labels, y_pred=predicted_classes, labels=[0, 1, 2], average='macro'))\n",
    "\n",
    "    return dict(\n",
    "        accuracy=sklearn.metrics.accuracy_score(labels, predicted_classes),\n",
    "        f1=macro_f1,\n",
    "        precision=macro_precision,\n",
    "        recall=macro_recall)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 14193,
     "status": "ok",
     "timestamp": 1620418938599,
     "user": {
      "displayName": "Ronald Seoh",
      "photoUrl": "",
      "userId": "10284188050297676522"
     },
     "user_tz": 240
    },
    "id": "9NXoBTs5h2eO"
   },
   "outputs": [],
   "source": [
    "# Batched iteration over the test set\n",
    "test_dataloader = torch.utils.data.DataLoader(\n",
    "    dataset=test_set,\n",
    "    batch_size=testing_batch_size,\n",
    "    pin_memory=use_pin_memory,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 104,
     "referenced_widgets": [
      "8bc7129ea76340a494cc7916b037c052",
      "5704871af06f4221b051f07f65c7e74a",
      "1ee6f145ba0e4da09d2d31bf1967da8d",
      "ce70f9442d0240118f9f76c3a485e383",
      "65ff039fa4f94365bd0d070c5f22be61",
      "5f1b30eb632447ccad493dc94bb2e168",
      "ac65a5f115be49628880b6582b47d873",
      "31348e5edcad45d4a8559cdd6b677ee2"
     ]
    },
    "executionInfo": {
     "elapsed": 95055,
     "status": "ok",
     "timestamp": 1620419019471,
     "user": {
      "displayName": "Ronald Seoh",
      "photoUrl": "",
      "userId": "10284188050297676522"
     },
     "user_tz": 240
    },
    "id": "LLcc_wZjhb0y",
    "outputId": "16de918b-7858-4cf6-d7b5-5194affd2151"
   },
   "outputs": [],
   "source": [
    "# Evaluation mode; gradients are not needed for testing\n",
    "classifier_model.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    predictions_test = torch.Tensor([])\n",
    "    labels_test = torch.Tensor([])\n",
    "\n",
    "    for test_batch in tqdm.notebook.tqdm(test_dataloader):\n",
    "\n",
    "        # Prompt-major layout: all reviews of the batch paired with the first prompt,\n",
    "        # then all reviews again paired with the second prompt, and so on.\n",
    "        reviews_repeated = []\n",
    "        for _ in sentiment_prompts:\n",
    "            reviews_repeated.extend(test_batch[\"text\"])\n",
    "\n",
    "        prompts_populated = [\n",
    "            prompt_spec['prompt'].format(aspect=aspect_term)\n",
    "            for prompt_spec in sentiment_prompts\n",
    "            for aspect_term in test_batch[\"aspect\"]\n",
    "        ]\n",
    "\n",
    "        # Review text is the first segment and the filled-in prompt the second;\n",
    "        # only the review may be truncated.\n",
    "        batch_encoded = tokenizer(\n",
    "            reviews_repeated,\n",
    "            prompts_populated,\n",
    "            padding='max_length',\n",
    "            truncation='only_first',\n",
    "            max_length=256,\n",
    "            return_tensors='pt')\n",
    "\n",
    "        batch_encoded.to(torch_device)\n",
    "\n",
    "        batch_scores = classifier_model(batch_encoded).to('cpu')\n",
    "\n",
    "        # Accumulate the outputs and gold labels of every batch\n",
    "        predictions_test = torch.cat([predictions_test, batch_scores])\n",
    "        labels_test = torch.cat([labels_test, test_batch[\"sentiment\"]])\n",
    "\n",
    "    # Compute metrics\n",
    "    test_metrics = compute_metrics(predictions_test, labels_test)\n",
    "    print(test_metrics)\n",
    "\n",
    "    # Save test_metrics into a file for later processing\n",
    "    metrics_path = os.path.join(trained_model_directory, 'test_metrics_subtask4.json')\n",
    "\n",
    "    with open(metrics_path, 'w') as test_metrics_json:\n",
    "        json.dump(test_metrics, test_metrics_json)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HjpA_0m1hb08"
   },
   "source": [
    "## Results visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 442
    },
    "executionInfo": {
     "elapsed": 95048,
     "status": "ok",
     "timestamp": 1620419019476,
     "user": {
      "displayName": "Ronald Seoh",
      "photoUrl": "",
      "userId": "10284188050297676522"
     },
     "user_tz": 240
    },
    "id": "w9G9AUeQhb09",
    "outputId": "06eb5449-2881-43c4-98b5-ad2cee1f476f"
   },
   "outputs": [],
   "source": [
    "# Calculate metrics and confusion matrix based upon predictions and true labels\n",
    "cm = sklearn.metrics.confusion_matrix(labels_test.detach().numpy(), predictions_test.detach().numpy().argmax(-1))\n",
    "\n",
    "df_cm = pd.DataFrame(\n",
    "    cm,\n",
    "    index=[i for i in [\"positive\", \"negative\", \"neutral\"]],\n",
    "    columns=[i for i in [\"positive\", \"negative\", \"neutral\"]])\n",
    "\n",
    "plt.figure(figsize=(10, 7))\n",
    "\n",
    "ax = sn.heatmap(df_cm, annot=True)\n",
    "\n",
    "ax.set(xlabel='Predicted Label', ylabel='True Label')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 95047,
     "status": "ok",
     "timestamp": 1620419019482,
     "user": {
      "displayName": "Ronald Seoh",
      "photoUrl": "",
      "userId": "10284188050297676522"
     },
     "user_tz": 240
    },
    "id": "4Wo_Yk0LY37d"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "celltoolbar": "Tags",
  "colab": {
   "collapsed_sections": [],
   "name": "bert_zero_shot_prompt_logit_softmax_subtask4.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "1ee6f145ba0e4da09d2d31bf1967da8d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_5f1b30eb632447ccad493dc94bb2e168",
      "max": 122,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_65ff039fa4f94365bd0d070c5f22be61",
      "value": 122
     }
    },
    "31348e5edcad45d4a8559cdd6b677ee2": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "5704871af06f4221b051f07f65c7e74a": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "5f1b30eb632447ccad493dc94bb2e168": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "65ff039fa4f94365bd0d070c5f22be61": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "8bc7129ea76340a494cc7916b037c052": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_1ee6f145ba0e4da09d2d31bf1967da8d",
       "IPY_MODEL_ce70f9442d0240118f9f76c3a485e383"
      ],
      "layout": "IPY_MODEL_5704871af06f4221b051f07f65c7e74a"
     }
    },
    "ac65a5f115be49628880b6582b47d873": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "ce70f9442d0240118f9f76c3a485e383": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_31348e5edcad45d4a8559cdd6b677ee2",
      "placeholder": "",
      "style": "IPY_MODEL_ac65a5f115be49628880b6582b47d873",
      "value": " 122/122 [01:21&lt;00:00,  1.50it/s]"
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}